{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ba5f8741",
   "metadata": {},
   "source": [
    "# Custom LLM Agent (with a ChatModel)\n",
    "\n",
    "This notebook goes through how to create your own custom agent based on a chat model.\n",
    "\n",
    "An LLM chat agent consists of four parts:\n",
    "\n",
    "- PromptTemplate: This is the prompt template that can be used to instruct the language model on what to do\n",
    "- ChatModel: This is the language model that powers the agent\n",
    "- `stop` sequence: Instructs the LLM to stop generating as soon as this string is found\n",
    "- OutputParser: This determines how to parse the LLM output into an `AgentAction` or `AgentFinish` object\n",
    "\n",
    "\n",
    "The LLMAgent is used in an AgentExecutor. This AgentExecutor can largely be thought of as a loop that:\n",
    "1. Passes user input and any previous steps to the Agent (in this case, the LLMAgent)\n",
    "2. If the Agent returns an `AgentFinish`, then return that directly to the user\n",
    "3. If the Agent returns an `AgentAction`, then use that to call a tool and get an `Observation`\n",
    "4. Repeat, passing the `AgentAction` and `Observation` back to the Agent until an `AgentFinish` is emitted.\n",
    "\n",
    "`AgentAction` is a response that consists of `tool` and `tool_input`. `tool` refers to which tool to use, and `tool_input` refers to the input to that tool. `log` can also be provided as more context (that can be used for logging, tracing, etc.).\n",
    "\n",
    "`AgentFinish` is a response that contains the final message to be sent back to the user. This should be used to end an agent run.\n",
    "\n",
    "In this notebook we walk through how to create a custom LLM agent."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fea4812c",
   "metadata": {},
   "source": [
    "## Set up environment\n",
    "\n",
    "Do necessary imports, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9af9734e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.agents import Tool, AgentExecutor, LLMSingleActionAgent, AgentOutputParser\n",
    "from langchain.prompts import BaseChatPromptTemplate\n",
    "from langchain import SerpAPIWrapper, LLMChain\n",
    "from langchain.chat_models import ChatOpenAI\n",
    "from typing import List, Union\n",
    "from langchain.schema import AgentAction, AgentFinish, HumanMessage\n",
    "import re"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6df0253f",
   "metadata": {},
   "source": [
    "## Set up tool\n",
    "\n",
    "Set up any tools the agent may want to use. These may need to be described in the prompt (so that the agent knows it can use them)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "becda2a1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define which tools the agent can use to answer user queries\n",
    "search = SerpAPIWrapper()\n",
    "tools = [\n",
    "    Tool(\n",
    "        name=\"Search\",\n",
    "        func=search.run,\n",
    "        description=\"useful for when you need to answer questions about current events\"\n",
    "    )\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2e7a075c",
   "metadata": {},
   "source": [
    "## Prompt Template\n",
    "\n",
    "This instructs the agent on what to do. Generally, the template should incorporate:\n",
    "\n",
    "- `tools`: which tools the agent has access to, and how and when to call them.\n",
    "- `intermediate_steps`: These are tuples of previous (`AgentAction`, `Observation`) pairs. These are generally not passed directly to the model, but the prompt template formats them in a specific way.\n",
    "- `input`: generic user input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "339b1bb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up the base template\n",
    "template = \"\"\"Answer the following questions as best you can, but speaking as a pirate might speak. You have access to the following tools:\n",
    "\n",
    "{tools}\n",
    "\n",
    "Use the following format:\n",
    "\n",
    "Question: the input question you must answer\n",
    "Thought: you should always think about what to do\n",
    "Action: the action to take, should be one of [{tool_names}]\n",
    "Action Input: the input to the action\n",
    "Observation: the result of the action\n",
    "... (this Thought/Action/Action Input/Observation can repeat N times)\n",
    "Thought: I now know the final answer\n",
    "Final Answer: the final answer to the original input question\n",
    "\n",
    "Begin! Remember to speak as a pirate when giving your final answer. Use lots of \"Arg\"s\n",
    "\n",
    "Question: {input}\n",
    "{agent_scratchpad}\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fd969d31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set up a prompt template\n",
    "class CustomPromptTemplate(BaseChatPromptTemplate):\n",
    "    # The template to use\n",
    "    template: str\n",
    "    # The list of tools available\n",
    "    tools: List[Tool]\n",
    "\n",
    "    def format_messages(self, **kwargs) -> List[HumanMessage]:\n",
    "        # Get the intermediate steps (AgentAction, Observation tuples)\n",
    "        # Format them in a particular way\n",
    "        intermediate_steps = kwargs.pop(\"intermediate_steps\")\n",
    "        thoughts = \"\"\n",
    "        for action, observation in intermediate_steps:\n",
    "            thoughts += action.log\n",
    "            thoughts += f\"\\nObservation: {observation}\\nThought: \"\n",
    "        # Set the agent_scratchpad variable to that value\n",
    "        kwargs[\"agent_scratchpad\"] = thoughts\n",
    "        # Create a tools variable from the list of tools provided\n",
    "        kwargs[\"tools\"] = \"\\n\".join([f\"{tool.name}: {tool.description}\" for tool in self.tools])\n",
    "        # Create a list of tool names for the tools provided\n",
    "        kwargs[\"tool_names\"] = \", \".join([tool.name for tool in self.tools])\n",
    "        formatted = self.template.format(**kwargs)\n",
    "        return [HumanMessage(content=formatted)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "798ef9fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = CustomPromptTemplate(\n",
    "    template=template,\n",
    "    tools=tools,\n",
    "    # This omits the `agent_scratchpad`, `tools`, and `tool_names` variables because those are generated dynamically\n",
    "    # This includes the `intermediate_steps` variable because that is needed\n",
    "    input_variables=[\"input\", \"intermediate_steps\"]\n",
    ")"
   ]
  },
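  {
   "cell_type": "markdown",
   "id": "c41d7e02",
   "metadata": {},
   "source": [
    "To sanity-check the template before wiring it into an agent, it can help to render it once with no intermediate steps and inspect the result. The cell below is a minimal sketch that reuses the `prompt` object defined above; the sample question is only an illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5b0e9a37",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the prompt for a first turn, before any tool has been called,\n",
    "# so the agent scratchpad is empty\n",
    "messages = prompt.format_messages(\n",
    "    input=\"How many people live in Canada?\",\n",
    "    intermediate_steps=[],\n",
    ")\n",
    "# format_messages returns a list holding a single HumanMessage\n",
    "print(messages[0].content)"
   ]
  },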
  {
   "cell_type": "markdown",
   "id": "ef3a1af3",
   "metadata": {},
   "source": [
    "## Output Parser\n",
    "\n",
    "The output parser is responsible for parsing the LLM output into either an `AgentAction` or an `AgentFinish`. This usually depends heavily on the prompt used.\n",
    "\n",
    "This is where you can change the parsing to do retries, handle whitespace, etc."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7c6fe0d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomOutputParser(AgentOutputParser):\n",
    "\n",
    "    def parse(self, llm_output: str) -> Union[AgentAction, AgentFinish]:\n",
    "        # Check if agent should finish\n",
    "        if \"Final Answer:\" in llm_output:\n",
    "            return AgentFinish(\n",
    "                # `return_values` is generally a dictionary with a single `output` key\n",
    "                # It is not recommended to try anything else at the moment :)\n",
    "                return_values={\"output\": llm_output.split(\"Final Answer:\")[-1].strip()},\n",
    "                log=llm_output,\n",
    "            )\n",
    "        # Parse out the action and action input\n",
    "        regex = r\"Action: (.*?)[\\n]*Action Input:[\\s]*(.*)\"\n",
    "        match = re.search(regex, llm_output, re.DOTALL)\n",
    "        if not match:\n",
    "            raise ValueError(f\"Could not parse LLM output: `{llm_output}`\")\n",
    "        action = match.group(1).strip()\n",
    "        action_input = match.group(2)\n",
    "        # Return the action and action input\n",
    "        return AgentAction(tool=action, tool_input=action_input.strip(\" \").strip('\"'), log=llm_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d278706a",
   "metadata": {},
   "outputs": [],
   "source": [
    "output_parser = CustomOutputParser()"
   ]
  },
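  {
   "cell_type": "markdown",
   "id": "8f2a6d19",
   "metadata": {},
   "source": [
    "A quick way to see what the parser does is to feed it two hand-written completions: one that requests a tool call and one that ends the run. The sample strings below are made up for illustration and are not real model output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3e07b54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A completion that asks for a tool call (hypothetical text, written by hand)\n",
    "action = output_parser.parse(\n",
    "    'Thought: I should look this up\\nAction: Search\\nAction Input: \"population of Canada\"'\n",
    ")\n",
    "print(type(action).__name__, action.tool, action.tool_input)\n",
    "\n",
    "# A completion that ends the run (also hypothetical)\n",
    "finish = output_parser.parse(\"Thought: I now know the final answer\\nFinal Answer: Arrr, 42\")\n",
    "print(type(finish).__name__, finish.return_values)"
   ]
  },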
  {
   "cell_type": "markdown",
   "id": "170587b1",
   "metadata": {},
   "source": [
    "## Set up LLM\n",
    "\n",
    "Choose the LLM you want to use!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f9d4c374",
   "metadata": {},
   "outputs": [],
   "source": [
    "llm = ChatOpenAI(temperature=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "caeab5e4",
   "metadata": {},
   "source": [
    "## Define the stop sequence\n",
    "\n",
    "This is important because it tells the LLM when to stop generation.\n",
    "\n",
    "This depends heavily on the prompt and model you are using. Generally, you want this to be whatever token you use in the prompt to denote the start of an `Observation` (otherwise, the LLM may hallucinate an observation for you)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "34be9f65",
   "metadata": {},
   "source": [
    "## Set up the Agent\n",
    "\n",
    "We can now combine everything to set up our agent."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "9b1cc2a2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# LLM chain consisting of the LLM and a prompt\n",
    "llm_chain = LLMChain(llm=llm, prompt=prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e4f5092f",
   "metadata": {},
   "outputs": [],
   "source": [
    "tool_names = [tool.name for tool in tools]\n",
    "agent = LLMSingleActionAgent(\n",
    "    llm_chain=llm_chain,\n",
    "    output_parser=output_parser,\n",
    "    stop=[\"\\nObservation:\"],\n",
    "    allowed_tools=tool_names\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aa8a5326",
   "metadata": {},
   "source": [
    "## Use the Agent\n",
    "\n",
    "Now we can use it!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "490604e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "653b1617",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
      "\u001b[32;1m\u001b[1;3mThought: Wot year be it now? That be important to know the answer.\n",
      "Action: Search\n",
      "Action Input: \"current population canada 2023\"\u001b[0m\n",
      "\n",
      "Observation:\u001b[36;1m\u001b[1;3m38,649,283\u001b[0m\u001b[32;1m\u001b[1;3mAhoy! That be the correct year, but the answer be in regular numbers. 'Tis time to translate to pirate speak.\n",
      "Action: Search\n",
      "Action Input: \"38,649,283 in pirate speak\"\u001b[0m\n",
      "\n",
      "Observation:\u001b[36;1m\u001b[1;3mBrush up on your “Pirate Talk” with these helpful pirate phrases. Aaaarrrrgggghhhh! Pirate catch phrase of grumbling or disgust. Ahoy! Hello! Ahoy, Matey, Hello ...\u001b[0m\u001b[32;1m\u001b[1;3mThat be not helpful, I'll just do the translation meself.\n",
      "Final Answer: Arrrr, thar be 38,649,283 scallywags in Canada as of 2023.\u001b[0m\n",
      "\n",
      "\u001b[1m> Finished chain.\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "'Arrrr, thar be 38,649,283 scallywags in Canada as of 2023.'"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "agent_executor.run(\"How many people live in canada as of 2023?\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.1"
  },
  "vscode": {
   "interpreter": {
    "hash": "18784188d7ecd866c0586ac068b02361a6896dc3a29b64f5cc957f09c590acef"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
